KSSDS: Korean Sentence Splitter for Dialogue Systems
KSSDS is a model based on the encoder of lcw99/t5-base-korean-text-summary and fine-tuned on AI Hub data.
It is a sentence splitter for Korean dialogue systems, built to segment Korean text produced by STT models such as Whisper into individual sentences.
For a detailed description, please see the KSSDS GitHub repository.
How to Use
1. Install from PyPI or GitHub (recommended)
- Installing from PyPI or GitHub is the most convenient way to use KSSDS (a usage sketch follows below).
- See the KSSDS GitHub repository for installation and usage instructions.
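For illustration only, here is a minimal sketch of how the installed package might be called. The import path, the KSSDS class, its model_path argument, and the split_sentences method are assumptions made for this sketch, not the documented API; refer to the GitHub repository for the actual interface.

# Hypothetical usage of the installed package (names below are assumptions, not the documented API)
from KSSDS import KSSDS  # assumed import path

kssds = KSSDS(model_path="ggomarobot/KSSDS")  # assumed constructor argument
sentences = kssds.split_sentences("안녕하세요. 오늘 날씨가 참 좋네요.")  # assumed method name

for idx, sentence in enumerate(sentences):
    print(f"{idx + 1}: {sentence}")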
2. Native Hugging Face Hub usage
You can download the model and the custom T5 encoder directly from the Hugging Face Hub.
Below is example code for the full sentence-splitting pipeline.
# Import the required libraries and modules
from huggingface_hub import hf_hub_download  # download files from the Hugging Face Hub
import sys
import os
from transformers import AutoTokenizer  # Hugging Face tokenizer
import torch
from typing import List

# Download T5_encoder.py and the model from the Hugging Face Hub
model_name = "ggomarobot/KSSDS"
file_path = hf_hub_download(repo_id=model_name, filename="T5_encoder.py")  # download the T5_encoder.py file
module_dir = os.path.dirname(file_path)  # directory that contains T5_encoder.py
sys.path.append(module_dir)  # add that directory to the Python path
from T5_encoder import T5ForTokenClassification  # custom T5 encoder model

tokenizer = AutoTokenizer.from_pretrained(model_name)  # load the tokenizer
model = T5ForTokenClassification.from_pretrained(model_name)  # load the custom T5 encoder model

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the selected device
model = model.to(device)
model.eval()  # set the model to evaluation mode

# Maximum input length per chunk
max_length = 512
def preprocess_text(input_text: str) -> List[dict]:
    """
    Split a long input text into chunks the model can process.

    Args:
        input_text (str): Long input text to process.

    Returns:
        List[dict]: List of chunks, each containing input_ids and an attention_mask.
    """
    tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)
    chunks = [tokenized_sequence[i:i + max_length] for i in range(0, len(tokenized_sequence), max_length)]
    processed_data = [{"input_ids": chunk, "attention_mask": [1] * len(chunk)} for chunk in chunks]
    return processed_data
def handle_repetitions(sentences: List[str], max_repeats: int = 60, detection_threshold: int = 70, max_phrase_length: int = 2) -> List[str]:
    """
    Handles single-word and phrase repetitions in a list of sentences, ensuring proper order and separation.

    Args:
        sentences (List[str]): List of input sentences to process.
        max_repeats (int): Maximum number of phrase repetitions to allow before splitting.
        detection_threshold (int): Minimum length of text for repetition detection.
        max_phrase_length (int): Maximum length of a phrase to consider for repetition detection.

    Returns:
        List[str]: The processed text split into sentences.
    """
    processed_sentences = []

    for sentence in sentences:
        words = sentence.split()
        if len(words) <= detection_threshold:
            processed_sentences.append(sentence)
            continue

        result_sentences = []
        current_sentence = []
        current_repetition = []

        def flush_sentence():
            """Flush the current sentence into result_sentences."""
            if current_sentence:
                result_sentences.append(" ".join(current_sentence))
                current_sentence.clear()

        def flush_repetition(phrase_length):
            """Flush the current repetition into result_sentences."""
            for i in range(0, len(current_repetition), phrase_length * max_repeats):
                chunk = current_repetition[i:i + phrase_length * max_repeats]
                result_sentences.append(" ".join(chunk))
            current_repetition.clear()

        def find_repeating_phrase(start_idx):
            """Find the smallest repeating phrase starting at the given index."""
            for phrase_length in range(1, max_phrase_length + 1):
                phrase = words[start_idx:start_idx + phrase_length]
                next_idx = start_idx + phrase_length
                if next_idx + phrase_length <= len(words) and words[next_idx:next_idx + phrase_length] == phrase:
                    return phrase
            return None

        i = 0
        while i < len(words):
            repeating_phrase = find_repeating_phrase(i)
            if repeating_phrase:
                # Flush any ongoing sentence before handling repetition
                flush_sentence()

                # Accumulate repeating phrases
                phrase_length = len(repeating_phrase)
                while i + phrase_length <= len(words) and words[i:i + phrase_length] == repeating_phrase:
                    current_repetition.extend(repeating_phrase)
                    i += phrase_length

                # Flush accumulated repetition if it reaches the threshold
                if len(current_repetition) >= phrase_length * max_repeats:
                    flush_repetition(phrase_length)
            else:
                # Add non-repeating words to the current sentence
                if current_repetition:
                    # Flush repetition before starting a new sentence
                    flush_repetition(1)  # Default to single-word repetition
                current_sentence.append(words[i])
                i += 1

        # Flush any remaining tokens
        flush_sentence()
        flush_repetition(1)  # Default to single-word repetition for the last chunk

        processed_sentences.extend(result_sentences)

    return processed_sentences
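
# --- Illustrative example (not part of the original card) -------------------
# A synthetic Whisper-style repetition, showing what handle_repetitions does:
# because the string below has more than detection_threshold (70) words, the run
# of repeated words is split into chunks of at most max_repeats (60) words each,
# while the trailing normal sentence is kept as its own string.
demo_sentences = ["네 " * 200 + "그래서 오늘 회의를 시작하겠습니다"]
for demo_sentence in handle_repetitions(demo_sentences):
    print(demo_sentence)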
def segment_predictions(input_ids: List[int], predictions: List[int]) -> List[List[int]]:
    """
    Segment predictions into sentences based on label 1 (sentence-ending).

    Args:
        input_ids (List[int]): List of input token IDs.
        predictions (List[int]): Corresponding prediction labels.

    Returns:
        List[List[int]]: Segmented sentences as lists of token IDs.
    """
    segments = []
    current_segment = []
    for token, label in zip(input_ids, predictions):
        if label == 1:  # Sentence-ending label
            if current_segment:
                current_segment.append(token)
                segments.append(current_segment)
                current_segment = []
            else:
                segments.append([token])
        else:
            current_segment.append(token)
    if current_segment:  # Append any remaining tokens
        segments.append(current_segment)
    return segments
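
# --- Illustrative example (not part of the original card) -------------------
# Toy token IDs and labels showing how segment_predictions cuts a sequence at
# every position labeled 1 (sentence-ending):
print(segment_predictions([10, 11, 12, 13, 14, 15], [0, 0, 1, 0, 0, 1]))
# Expected output: [[10, 11, 12], [13, 14, 15]]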
def decode_predictions(input_ids: List[int], predictions: List[int], tokenizer, carry_over=None):
    """
    Decode model predictions into sentences, handling carry-over tokens across chunks.

    Args:
        input_ids (List[int]): Input token IDs.
        predictions (List[int]): Prediction labels.
        tokenizer: Hugging Face tokenizer instance.
        carry_over (List[int], optional): Tokens carried over from the previous chunk.

    Returns:
        Tuple[List[str], List[int]]: Decoded sentences and remaining carry-over tokens.
    """
    if carry_over is None:
        carry_over = []

    sentences = []
    tokens = carry_over + input_ids  # Include carry-over tokens
    labels = [0] * len(carry_over) + predictions  # Carry-over tokens have label 0

    segmented = segment_predictions(tokens, labels)
    for segment in segmented[:-1]:  # Decode all segments except the last one
        sentence = tokenizer.decode(segment, skip_special_tokens=False, clean_up_tokenization_spaces=False).strip()
        if sentence:
            sentences.append(sentence)

    # Handle carry-over for the last segment
    carry_over = segmented[-1] if segmented[-1] and labels[len(tokens) - len(segmented[-1])] != 1 else []
    return sentences, carry_over
def inference(input_text: str) -> List[str]:
    """
    Perform sentence splitting on input text using the HF Hub model.

    Args:
        input_text (str): Input text to split into sentences.

    Returns:
        List[str]: List of split sentences.
    """
    chunks = preprocess_text(input_text)
    carry_over_tokens = []
    sentences = []

    for chunk in chunks:
        # Prepare inputs for the model
        input_ids_tensor = torch.tensor([chunk["input_ids"]], dtype=torch.long, device=device)
        attention_mask_tensor = torch.tensor([chunk["attention_mask"]], dtype=torch.long, device=device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
        predictions = outputs.logits.argmax(dim=-1).squeeze().tolist()

        # Ensure predictions is a list
        if isinstance(predictions, int):
            predictions = [predictions]

        input_ids = chunk["input_ids"]
        decoded_sentences, carry_over_tokens = decode_predictions(input_ids, predictions, tokenizer, carry_over_tokens)
        sentences.extend(decoded_sentences)

    # Process any remaining carry-over tokens
    if carry_over_tokens:
        remaining_sentence = tokenizer.decode(carry_over_tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False).strip()
        if remaining_sentence:
            sentences.append(remaining_sentence)

    return handle_repetitions(sentences)
# Usage example
input_text = "안녕하세요. 오늘 날씨가 참 좋네요. 저는 산책을 나갈 예정입니다."  # "Hello. The weather is really nice today. I am planning to go for a walk."
split_sentences = inference(input_text)

# Print the results
for idx, sentence in enumerate(split_sentences):
    print(f"{idx + 1}: {sentence}")
About the KSSDS_NO_LF Model
KSSDS_NO_LF is a model trained without the Length Filter, for an ablation study.
When processing long texts, it performs comparatively less precise sentence splitting.
It is used in the same way as the main model, but it may not be suitable for practical use outside of research and comparison purposes.
How to Use
Replace model_name in the Hugging Face Hub native usage snippet above with "ggomarobot/KSSDS_NO_LF":
# Load tokenizer and model
model_name = "ggomarobot/KSSDS_NO_LF"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForTokenClassification.from_pretrained(model_name)