import hashlib
import pickle
from itertools import zip_longest
from pathlib import Path

import gradio as gr
import numpy as np
import ruptures as rpt
import torch
from sentence_transformers import SentenceTransformer, util

from util import sent_tokenize

CACHE_DIR = '.cache'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
_ST_MODELS = ['all-mpnet-base-v2', 'multi-qa-mpnet-base-dot-v1', 'all-MiniLM-L12-v2']


def embed_sentences(sentences, embedder_fn, cache_path):
    """Embed sentences with `embedder_fn`, caching the result on disk."""
    if Path(cache_path).exists():
        print(f'Loading embeddings from cache: {cache_path}')
        with open(cache_path, 'rb') as file:
            embedded_sents = pickle.load(file)
    else:
        print(f'Embedding sentences and saving to cache: {cache_path}')
        embedded_sents = embedder_fn(sentences)
        assert len(embedded_sents) == len(sentences)
        # Make sure the cache directory exists before writing.
        Path(cache_path).parent.mkdir(parents=True, exist_ok=True)
        with open(cache_path, 'wb') as file:
            pickle.dump(embedded_sents, file)
    return embedded_sents


def calculate_cosine_similarities(embedded_sents, k=1, pool='mean'):
    """Score each gap between consecutive sentences.

    For the gap after sentence i, up to `k` embeddings on each side of the gap
    are compared pairwise and pooled into a single cosine-similarity score.
    """
    def cosine_similarity(a, b):
        sim = util.cos_sim(a, b)
        if pool == 'mean':
            return sim.mean().item()
        elif pool == 'max':
            return sim.max().item()
        elif pool == 'min':
            return sim.min().item()
        else:
            raise ValueError(f'Invalid pooling method: {pool}')

    cosine_sims = []
    for i in range(len(embedded_sents) - 1):
        lctx = embedded_sents[max(0, i - k + 1) : i + 1]  # up to k sentences left of the gap
        rctx = embedded_sents[i + 1 : i + k + 1]          # up to k sentences right of the gap
        cosine_sims.append(cosine_similarity(lctx, rctx))

    assert len(cosine_sims) == len(embedded_sents) - 1, \
        f'{len(cosine_sims)} != {len(embedded_sents) - 1}'
    return cosine_sims

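# Worked example of the windowing above (an illustration, not executed by the app):
# with k=2 and embeddings [e0, e1, e2, e3], the gap after sentence 1 compares
# lctx = [e0, e1] against rctx = [e2, e3]. util.cos_sim returns the 2x2 matrix of
# pairwise similarities; 'mean' averages all four entries, 'max' keeps only the
# strongest pair, and 'min' the weakest. At the edges the windows shrink, e.g.
# the first gap uses lctx = [e0] regardless of k.
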
def predict_boundaries(cosine_sims, threshold):
    """Turn gap similarities into boundary probabilities and threshold them into 0/1 labels."""
    probs = [1.0 - sim for sim in cosine_sims]
    preds = [1 if prob >= threshold else 0 for prob in probs]
    return preds, probs


def output_segments(sents, preds, probs):
    assert len(sents) - 1 == len(preds) == len(probs), \
        f'{len(sents) - 1} != {len(preds)} != {len(probs)}'

    def iter_segments(sents, preds, probs):
        # The last sentence has no following gap, so zip_longest pads its
        # pred/prob with None and the trailing segment is flushed at the end.
        segment = []
        for sent, pred, prob in zip_longest(sents, preds, probs):
            segment.append({
                'text': sent,
                'prob': round(prob, 4) if prob is not None else None,
            })
            if pred == 1:
                yield segment
                segment = []
        if len(segment) > 0:
            yield segment

    out = {
        'metadata': {},
        'chunks': [],
    }
    n_segs = 0
    n_sents = 0
    for segment in iter_segments(sents, preds, probs):
        out['chunks'].append(segment)
        n_segs += 1
        n_sents += len(segment)
    out['metadata'] = {
        'n_chunks': n_segs,
        'n_sents': n_sents,
        'prob_mean': round(float(np.mean(probs)), 4),
        'prob_std': round(float(np.std(probs)), 4),
        'prob_min': round(min(probs), 4),
        'prob_max': round(max(probs), 4),
    }

    out_text = "\n-------------------------\n".join(
        "\n".join(sent['text'] for sent in segment) for segment in out['chunks']
    )

    def plot_regimes(signal, preds):
        """Plot the boundary-probability signal with predicted segments shaded."""
        def get_bkps_from_labels(labels):
            return [i + 1 for i, label in enumerate(labels) if label == 1]

        bkps = get_bkps_from_labels(preds)
        # ruptures expects the final breakpoint to equal the signal length,
        # so close the last segment explicitly.
        if not bkps or bkps[-1] != len(signal):
            bkps.append(len(signal))
        fig, [ax] = rpt.display(np.array(signal), bkps, figsize=(10, 5), dpi=250)
        y_min = max(0.0, min(signal) - 0.1)
        y_max = min(1.0, max(signal) + 0.1)
        ax.set_ylim(y_min, y_max)
        ax.set_title("Segment Regimes")
        ax.set_xlabel("Sentence Index")
        ax.set_ylabel("Semantic Shift Probability")
        fig.tight_layout()
        return fig

    fig = plot_regimes(probs, preds)
    return out_text, out, fig


def text_segmentation(input_text, model_name, k, pool, threshold):
    if model_name in _ST_MODELS:
        model = SentenceTransformer(model_name, device=DEVICE)
        embedder_fn = model.encode
    else:
        raise ValueError(f'Invalid model name: {model_name}')

    sents = sent_tokenize(input_text, method='nltk', initial_split_sep='\n')

    # Key the cache on both the text and the model, so switching models
    # does not reuse stale embeddings.
    cache_id = hashlib.md5(f'{model_name}:{input_text}'.encode()).hexdigest()
    cache_path = Path(CACHE_DIR) / f'{cache_id}.pkl'
    embedded_sents = embed_sentences(sents, embedder_fn, cache_path=cache_path)

    cosine_sims = calculate_cosine_similarities(embedded_sents, k=k, pool=pool)
    preds, probs = predict_boundaries(cosine_sims, threshold=threshold)
    return output_segments(sents, preds, probs)


with gr.Blocks() as app:
    gr.Markdown("""
    # LLM TextTiling Demo

    An **extended** approach to text segmentation that combines **TextTiling** with **LLM embeddings**.
    Simply provide your text, choose an embedding model, and adjust the segmentation parameters
    (window size, pooling, threshold). The demo will split your text into coherent segments based
    on **semantic shifts**.
    Refer to the [README](https://huggingface.co/spaces/saeedabc/llm-text-tiling-demo/blob/main/README.md) for more details.
    """)

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Enter your text here...", lines=15)
            with gr.Row():
                with gr.Column():
                    model_name = gr.Dropdown(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                with gr.Column():
                    pool = gr.Dropdown(choices=['max', 'mean', 'min'], label="Pooling Strategy", value='max')
            with gr.Row():
                with gr.Column():
                    threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Threshold", value=0.5)
                with gr.Column():
                    k = gr.Slider(minimum=1, maximum=10, step=1, label="Window Size", value=3)
            submit_button = gr.Button("Chunk Text")

        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Output Text"):
                    output_text = gr.Textbox(label="Output Text", placeholder="Chunks will appear here...", lines=22)
                with gr.Tab("Output Json"):
                    output_json = gr.Json(label="Output Json", open=False, max_height=500)
                with gr.Tab("Output Visualization"):
                    output_fig = gr.Plot(label="Output Visualization")

    submit_button.click(
        text_segmentation,
        inputs=[input_text, model_name, k, pool, threshold],
        outputs=[output_text, output_json, output_fig],
    )

    examples = gr.Examples(
        examples=[
            ["Rib Mountain is a census-designated place (CDP) in the town of Rib Mountain in Marathon County, Wisconsin, United States. "
             "The population was 5,651 at the 2010 census. "
             "The community is named for Rib Mountain. "
             "According to the United States Census Bureau, the CDP has a total area of 33.8 km² (13.0 mi²). "
             "31.4 km² (12.1 mi²) of it is land and 2.4 km² (0.9 mi²) of it (6.98%) is water. "
             "As of the census of 2000, there were 6,059 people, 2,211 households, and 1,782 families residing in the CDP. "
             "The population density was 193.0/km² (499.8/mi²). "
             "There were 2,278 housing units at an average density of 72.6/km² (187.9/mi²).",
             "all-mpnet-base-v2", 3, 'max', 0.52],
        ],
        inputs=[input_text, model_name, k, pool, threshold],
    )

if __name__ == '__main__':
    Path(CACHE_DIR).mkdir(exist_ok=True)
    # Launch the app
    app.launch()
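
# Programmatic usage sketch (hypothetical, bypassing the Gradio UI; 'article.txt'
# is a placeholder input file):
#
#     text = Path('article.txt').read_text()
#     out_text, out_json, fig = text_segmentation(
#         text, model_name='all-mpnet-base-v2', k=3, pool='max', threshold=0.5)
#     print(out_text)            # chunks separated by dashed lines
#     fig.savefig('regimes.png')
#
# text_segmentation returns the same three values the UI tabs render: the
# chunked text, the JSON payload with per-sentence boundary probabilities,
# and the matplotlib figure of the probability signal.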