import hashlib
import pickle
from pathlib import Path
from itertools import zip_longest

import gradio as gr
import numpy as np
import ruptures as rpt
import torch
from sentence_transformers import SentenceTransformer, util

from util import sent_tokenize

CACHE_DIR = '.cache'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
_ST_MODELS = ['all-mpnet-base-v2', 'multi-qa-mpnet-base-dot-v1', 'all-MiniLM-L12-v2']
def embed_sentences(sentences, embedder_fn, cache_path):
    """Embed `sentences` with `embedder_fn`, caching the result on disk at `cache_path`."""
    if Path(cache_path).exists():
        print(f'Loading embeddings from cache: {cache_path}')
        with open(cache_path, 'rb') as file:
            embedded_sents = pickle.load(file)
    else:
        print(f'Embedding sentences and saving to cache: {cache_path}')
        embedded_sents = embedder_fn(sentences)
        assert len(embedded_sents) == len(sentences)
        with open(cache_path, 'wb') as file:
            pickle.dump(embedded_sents, file)
    return embedded_sents
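
# Illustrative sketch only (kept in a comment so it is not executed at import
# time): with a sentence-transformers encoder, `embed_sentences` returns one
# vector per sentence, and a second call with the same cache path loads the
# pickle instead of re-encoding. The file name 'demo.pkl' is hypothetical.
#
#   model = SentenceTransformer('all-MiniLM-L12-v2')
#   embs = embed_sentences(['First sentence.', 'Second one.'], model.encode,
#                          cache_path='.cache/demo.pkl')
#   # embs.shape == (2, 384) for this model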
def calculate_cosine_similarities(embedded_sents, k=1, pool='mean'):
    """For each gap between adjacent sentences, compare a window of up to `k`
    sentence embeddings on each side and pool the pairwise cosine similarities
    into a single score."""

    def cosine_similarity(a, b):
        sim = util.cos_sim(a, b)
        if pool == 'mean':
            return sim.mean().item()
        elif pool == 'max':
            return sim.max().item()
        elif pool == 'min':
            return sim.min().item()
        else:
            raise ValueError(f'Invalid pooling method: {pool}')

    cosine_sims = []
    for i in range(len(embedded_sents) - 1):
        lctx = embedded_sents[max(0, i - k + 1) : i + 1]  # left window, ending at sentence i
        rctx = embedded_sents[i + 1 : i + k + 1]          # right window, starting at sentence i+1
        sim = cosine_similarity(lctx, rctx)
        cosine_sims.append(sim)
    assert len(cosine_sims) == len(embedded_sents) - 1, f'{len(cosine_sims)} != {len(embedded_sents) - 1}'
    return cosine_sims
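
# A hedged sketch (comment only, not executed) of how the pooled window score
# behaves on toy vectors, assuming k=2 and pool='mean':
#
#   embs = torch.eye(4)  # four mutually orthogonal "sentence" embeddings
#   sims = calculate_cosine_similarities(embs, k=2, pool='mean')
#   # len(sims) == 3; every score is 0.0 since all window pairs are orthogonal.
#   # E.g. the gap between sentences 1 and 2 compares lctx = embs[0:2] against
#   # rctx = embs[2:4], a 2x2 similarity matrix pooled to one value.
#   # With k=1 this reduces to classic adjacent-pair similarity.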
def predict_boundaries(cosine_sims, threshold):
    """Turn similarities into boundary probabilities (1 - sim) and threshold them."""
    probs = [1.0 - sim for sim in cosine_sims]
    preds = [1 if prob >= threshold else 0 for prob in probs]
    return preds, probs
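
# Worked example (comment only, not executed): similarities [0.9, 0.3, 0.8]
# give boundary probabilities [0.1, 0.7, 0.2]; with threshold=0.5 only the
# second gap becomes a boundary:
#
#   preds, probs = predict_boundaries([0.9, 0.3, 0.8], threshold=0.5)
#   # probs == [0.1, 0.7, 0.2] (up to float rounding), preds == [0, 1, 0]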
def output_segments(sents, preds, probs):
    """Group sentences into segments at predicted boundaries and build the
    text, JSON, and plot outputs."""
    assert len(sents) - 1 == len(preds) == len(probs), f'{len(sents) - 1} != {len(preds)} != {len(probs)}'

    def iter_segments(sents, preds, probs):
        # `preds`/`probs` are one shorter than `sents`; zip_longest pads the
        # final sentence with (None, None) so it is always flushed below.
        segment = []
        for sent, pred, prob in zip_longest(sents, preds, probs):
            segment.append({
                'text': sent,
                'prob': round(prob, 4) if prob is not None else None,
            })
            if pred == 1:
                yield segment
                segment = []
        if len(segment) > 0:
            yield segment

    out = {
        'metadata': {},
        'chunks': [],
    }
    n_segs = 0
    n_sents = 0
    for segment in iter_segments(sents, preds, probs):
        out['chunks'].append(segment)
        n_segs += 1
        n_sents += len(segment)
    out['metadata'] = {
        'n_chunks': n_segs,
        'n_sents': n_sents,
        'prob_mean': round(np.mean(probs), 4),
        'prob_std': round(np.std(probs), 4),
        'prob_min': round(min(probs), 4),
        'prob_max': round(max(probs), 4),
    }

    out_text = "\n-------------------------\n".join(
        "\n".join(sent['text'] for sent in segment) for segment in out['chunks']
    )

    def plot_regimes(signal, preds):
        def get_bkps_from_labels(labels):
            return [i + 1 for i, l in enumerate(labels) if l == 1]

        # ruptures expects the final index to appear as a breakpoint marking
        # the end of the signal.
        preds = preds + [1]
        bkps = get_bkps_from_labels(preds)
        fig, [ax] = rpt.display(np.array(signal), bkps, figsize=(10, 5), dpi=250)
        y_min = max(0.0, min(signal) - 0.1)
        y_max = min(1.0, max(signal) + 0.1)
        ax.set_ylim(y_min, y_max)
        ax.set_title("Segment Regimes")
        ax.set_xlabel("Sentence Index")
        ax.set_ylabel("Semantic Shift Probability")
        fig.tight_layout()
        return fig

    fig = plot_regimes(probs, preds)
    return out_text, out, fig
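
# Shape of the returned values, as a hedged sketch (the literal numbers are
# made up): `out_text` joins chunks with a dashed separator, and `out` looks
# roughly like
#
#   {
#     'metadata': {'n_chunks': 2, 'n_sents': 5, 'prob_mean': 0.31, ...},
#     'chunks': [
#       [{'text': 'First sentence.', 'prob': 0.12}, ...],  # chunk 1
#       [..., {'text': 'Last sentence.', 'prob': None}],   # final prob is the None pad
#     ],
#   }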
def text_segmentation(input_text, model_name, k, pool, threshold):
    """Run the full pipeline: tokenize, embed (with caching), score gaps, segment."""
    if model_name in _ST_MODELS:
        model = SentenceTransformer(model_name, device=DEVICE)
        embedder_fn = model.encode
    else:
        raise ValueError(f'Invalid model name: {model_name}')

    sents = sent_tokenize(input_text, method='nltk', initial_split_sep='\n')

    # Key the cache on both the model and the text, since embeddings differ per model.
    cache_id = hashlib.md5(f'{model_name}:{input_text}'.encode()).hexdigest()
    cache_path = Path(CACHE_DIR) / f'{cache_id}.pkl'
    embedded_sents = embed_sentences(sents, embedder_fn, cache_path=cache_path)

    cosine_sims = calculate_cosine_similarities(embedded_sents, k=int(k), pool=pool)  # cast in case the slider delivers a float
    preds, probs = predict_boundaries(cosine_sims, threshold=threshold)
    return output_segments(sents, preds, probs)
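
# The pipeline can be exercised without the UI; a sketch (comment only,
# assuming the model can be downloaded or is cached locally):
#
#   text = "Cats are small felines. They purr. GPUs accelerate matrix math."
#   out_text, out, fig = text_segmentation(text, 'all-MiniLM-L12-v2', k=1,
#                                          pool='max', threshold=0.5)
#   print(out_text)  # sentences grouped into chunks separated by dashed lines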
with gr.Blocks() as app:
    gr.Markdown("""
    # LLM TextTiling Demo
    An **extended** approach to text segmentation that combines **TextTiling** with **LLM embeddings**. Provide your text, choose an embedding model, and adjust the segmentation parameters (window size, pooling, threshold); the demo then splits the text into coherent segments at **semantic shifts**. See the [README](https://huggingface.co/spaces/saeedabc/llm-text-tiling-demo/blob/main/README.md) for more details.
    """)
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Enter your text here...", lines=15)
            with gr.Row():
                with gr.Column():
                    model_name = gr.Dropdown(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                with gr.Column():
                    pool = gr.Dropdown(choices=['max', 'mean', 'min'], label="Pooling Strategy", value='max')
            with gr.Row():
                with gr.Column():
                    threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Threshold", value=0.5)
                with gr.Column():
                    k = gr.Slider(minimum=1, maximum=10, step=1, label="Window Size", value=3)
            submit_button = gr.Button("Chunk Text")
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Output Text"):
                    output_text = gr.Textbox(label="Output Text", placeholder="Chunks will appear here...", lines=22)
                with gr.Tab("Output Json"):
                    output_json = gr.Json(label="Output Json", open=False, max_height=500)
                with gr.Tab("Output Visualization"):
                    output_fig = gr.Plot(label="Output Visualization")

    submit_button.click(text_segmentation, inputs=[input_text, model_name, k, pool, threshold], outputs=[output_text, output_json, output_fig])
    examples = gr.Examples(
        examples=[
            ["Rib Mountain is a census-designated place (CDP) in the town of Rib Mountain in Marathon County, Wisconsin, United States. "
             "The population was 5,651 at the 2010 census. "
             "The community is named for Rib Mountain. "
             "According to the United States Census Bureau, the CDP has a total area of 33.8 km² (13.0 mi²). "
             "31.4 km² (12.1 mi²) of it is land and 2.4 km² (0.9 mi²) of it (6.98%) is water. "
             "As of the census of 2000, there were 6,059 people, 2,211 households, and 1,782 families residing in the CDP. "
             "The population density was 193.0/km² (499.8/mi²). "
             "There were 2,278 housing units at an average density of 72.6/km² (187.9/mi²).", "all-mpnet-base-v2", 3, 'max', 0.52],
        ],
        inputs=[input_text, model_name, k, pool, threshold],
    )
if __name__ == '__main__':
    Path(CACHE_DIR).mkdir(exist_ok=True)
    # Launch the Gradio app
    app.launch()