# Segmentation function from Batchalign
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

import nltk
from nltk.tokenize import sent_tokenize

# Fetch the Punkt sentence-tokenizer data at import time (a no-op if it is
# already installed)
nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)

# Input: a string of space-separated words, no punctuation, all lower case.
# Output: one label per word, where 0 means the word is not the last word of
# a c-unit and 1 means it is the last word of a c-unit.
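# Hypothetical example of the output format (illustrative, not actual model
# output): segment_batchalign("hello there how are you") -> [0, 1, 0, 0, 1]
# would mean "hello there" and "how are you" are two c-units.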
def segment_batchalign(text: str) -> list[int]:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load tokenizer and model locally.
    # Note: they are reloaded on every call; for repeated use, caching them at
    # module level avoids the reload cost (see the sketch after this function).
    model_path = "talkbank/CHATUtterance-en"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    model.to(DEVICE)
    model.eval()

    # Defensive normalization: lower-case the input and strip the punctuation
    # the model itself is expected to predict
    text = text.lower().replace(".", "").replace(",", "")
    words = text.split()

    # Tokenize the pre-split words (keeping word alignment) and predict one
    # label per subword token
    tokd = tokenizer([words], return_tensors="pt", is_split_into_words=True).to(DEVICE)
    with torch.no_grad():
        logits = model(**tokd).logits
    predictions = torch.argmax(logits, dim=2).squeeze(0).cpu().tolist()

    # Align predictions with words: keep only the prediction made for the
    # first subword token of each word
    word_ids = tokd.word_ids(0)
    result_words = []
    seen = set()

    for i, word_idx in enumerate(word_ids):
        if word_idx is None or word_idx in seen:
            continue
        seen.add(word_idx)

        pred = predictions[i]
        word = words[word_idx]

        # Label scheme used below: 1 marks an utterance-initial word
        # (capitalize it); 2-5 attach predicted punctuation to the word
        if pred == 1:
            word = word[0].upper() + word[1:]
        elif pred == 2:
            word += "."
        elif pred == 3:
            word += "?"
        elif pred == 4:
            word += "!"
        elif pred == 5:
            word += ","

        result_words.append(word)

    # Re-join the words (now carrying predicted capitalization and
    # punctuation) into a string and split it into sentences
    sentence = tokenizer.convert_tokens_to_string(result_words)
    try:
        sentences = sent_tokenize(sentence)
    except LookupError:
        # Fallback in case the Punkt data was not available at import time
        nltk.download('punkt')
        sentences = sent_tokenize(sentence)

    # Convert sentences to per-word boundary labels: 0 for every word except
    # the last word of each sentence, which gets 1
    boundaries = []
    for sent in sentences:
        sent_word_count = len(sent.split())
        boundaries += [0] * (sent_word_count - 1) + [1]
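
    # Assumption: the total word count across sentences still equals
    # len(words), since sent_tokenize should only split at the punctuation the
    # model inserted; if it ever re-splits a word, the labels would drift out
    # of alignment with the input.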

    # Collapse consecutive boundary labels: a one-word c-unit is merged into
    # the c-unit that precedes it, so a single word never stands alone
    for i in range(1, len(boundaries)):
        if boundaries[i - 1] == 1 and boundaries[i] == 1:
            boundaries[i - 1] = 0

    return boundaries
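

# --- Not part of the original Batchalign code ---
# A minimal sketch of how the tokenizer/model pair could be cached so that
# repeated calls to segment_batchalign avoid reloading the checkpoint; the
# helper name and cache structure here are illustrative assumptions.
_MODEL_CACHE: dict[str, tuple] = {}

def _get_chat_model(model_path: str = "talkbank/CHATUtterance-en"):
    # Load the tokenizer/model once per checkpoint and reuse them afterwards
    if model_path not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForTokenClassification.from_pretrained(model_path)
        model.eval()
        _MODEL_CACHE[model_path] = (tokenizer, model)
    return _MODEL_CACHE[model_path]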







if __name__ == "__main__":
    # Test the segmentation
    # test_text = "once a horse met elephant and then they saw a ball in a pool and then the horse tried to swim and get the ball they might be the same but they are doing something what do you think they are doing"
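    # (In CHAT transcripts, "xxx" marks unintelligible speech.)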
    test_text = "sir can I have balloon and the sir say yes you can and he said five dollars that xxx and and he is like where is that they his tether is right there and and he said and the bunny said oopsies I do not have money and the doc and the and the and the bunny runned for the doctor an and he says doctor doctor I want a balloon here is the money and you can have the balloons both of them now they are happy the end"
    print(f"Input text: {test_text}")
    print(f"Words: {test_text.split()}")
    
    labels = segment_batchalign(test_text) 
    print(f"Segment labels: {labels}")
    
    # Show segmented text
    words = test_text.split()
    segments = []
    current_segment = []
    
    # Note: zip() stops at the shorter sequence, so any label/word count
    # mismatch would silently truncate the reconstruction here
    for word, label in zip(words, labels):
        current_segment.append(word)
        if label == 1:
            segments.append(" ".join(current_segment))
            current_segment = []
    
    # Add remaining words if any
    if current_segment:
        segments.append(" ".join(current_segment))
    
    print("\nSegmented text:")
    for i, segment in enumerate(segments, 1):
        print(f"Segment {i}: {segment}")