# Segmentation function from Batchalign
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from nltk.tokenize import sent_tokenize
import nltk

nltk.download('punkt_tab')
nltk.download('punkt')
# Input: a string of space-separated words, no punctuation, all lower case.
# Output: one label per word, where 0 means the corresponding word is not the
# last word of a c-unit and 1 means it is the last word of a c-unit.
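# Example (illustrative only): for an input containing the two c-units
# "the dog ran" and "he barked", i.e. "the dog ran he barked",
# the expected output is [0, 0, 1, 0, 1].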
def segment_batchalign(text: str) -> list[int]:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the tokenizer and token-classification model
    model_path = "talkbank/CHATUtterance-en"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    model.to(DEVICE)
    model.eval()

    # Normalize the input: lower case, strip periods and commas
    text = text.lower().replace(".", "").replace(",", "")
    words = text.split()

    # Tokenize the pre-split words and run the model
    tokd = tokenizer([words], return_tensors="pt", is_split_into_words=True).to(DEVICE)
    with torch.no_grad():
        logits = model(**tokd).logits
    predictions = torch.argmax(logits, dim=2).squeeze(0).cpu().tolist()

    # Align predictions with words: keep only the prediction for the first
    # sub-token of each word
    word_ids = tokd.word_ids(0)
    result_words = []
    seen = set()
    for i, word_idx in enumerate(word_ids):
        if word_idx is None or word_idx in seen:
            continue
        seen.add(word_idx)
        pred = predictions[i]
        word = words[word_idx]
        # Apply the predicted label as a surface edit: 1 capitalizes the word
        # (utterance start), 2-5 append ".", "?", "!", "," respectively
        if pred == 1:
            word = word[0].upper() + word[1:]
        elif pred == 2:
            word += "."
        elif pred == 3:
            word += "?"
        elif pred == 4:
            word += "!"
        elif pred == 5:
            word += ","
        result_words.append(word)

    # Join the punctuated words back into a string and split into sentences
    sentence = tokenizer.convert_tokens_to_string(result_words)
    try:
        sentences = sent_tokenize(sentence)
    except LookupError:
        nltk.download('punkt')
        sentences = sent_tokenize(sentence)

    # Convert sentences to per-word boundary labels
    boundaries = []
    for sent in sentences:
        sent_word_count = len(sent.split())
        boundaries += [0] * (sent_word_count - 1) + [1]
    # Merge runs of consecutive boundary labels, keeping only the last one
    for i in range(1, len(boundaries)):
        if boundaries[i - 1] == 1 and boundaries[i] == 1:
            boundaries[i - 1] = 0
    return boundaries
if __name__ == "__main__":
    # Test the segmentation
    # test_text = "once a horse met elephant and then they saw a ball in a pool and then the horse tried to swim and get the ball they might be the same but they are doing something what do you think they are doing"
    test_text = "sir can I have balloon and the sir say yes you can and he said five dollars that xxx and and he is like where is that they his tether is right there and and he said and the bunny said oopsies I do not have money and the doc and the and the and the bunny runned for the doctor an and he says doctor doctor I want a balloon here is the money and you can have the balloons both of them now they are happy the end"
    print(f"Input text: {test_text}")
    print(f"Words: {test_text.split()}")
    labels = segment_batchalign(test_text)
    print(f"Segment labels: {labels}")

    # Show the text split into segments at the predicted boundaries
    words = test_text.split()
    segments = []
    current_segment = []
    for word, label in zip(words, labels):
        current_segment.append(word)
        if label == 1:
            segments.append(" ".join(current_segment))
            current_segment = []
    # Add remaining words if any
    if current_segment:
        segments.append(" ".join(current_segment))

    print("\nSegmented text:")
    for i, segment in enumerate(segments, 1):
        print(f"Segment {i}: {segment}")