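# NOTE: the three lines below are web-page residue picked up when this file
# was copied; they are not program text.
# saeedabc's picture
# Update readme info
# d331bf4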
import hashlib
import pickle
from pathlib import Path
from itertools import zip_longest
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
import numpy as np
import ruptures as rpt
from util import sent_tokenize
# Directory under which text_segmentation places the pickled embedding files.
CACHE_DIR = '.cache'
# Device string handed to SentenceTransformer: 'cuda' when torch reports a
# usable GPU, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Model names accepted by text_segmentation and offered in the UI dropdown.
_ST_MODELS = ['all-mpnet-base-v2', 'multi-qa-mpnet-base-dot-v1', 'all-MiniLM-L12-v2']
def embed_sentences(sentences, embedder_fn, cache_path):
    """Return one embedding per sentence, backed by an on-disk pickle cache.

    Args:
        sentences: Sequence of sentences to embed.
        embedder_fn: Callable that receives the whole sentence sequence and
            returns a sized collection with one embedding per sentence.
        cache_path: File (``str`` or ``Path``) the embeddings are pickled to
            and reloaded from.

    Returns:
        The object produced by ``embedder_fn`` for ``sentences``, either
        freshly computed or loaded back from ``cache_path``.

    Raises:
        AssertionError: If ``embedder_fn`` returns a different number of
            items than it was given.
    """
    cache_path = Path(cache_path)

    if cache_path.exists():
        print(f'Loading embeddings from cache: {cache_path}')
        # NOTE(security): unpickling can execute arbitrary code. This is only
        # acceptable while the cache file is written exclusively by this
        # function; never point cache_path at externally supplied files.
        try:
            with open(cache_path, 'rb') as file:
                embedded_sents = pickle.load(file)
        except (EOFError, pickle.UnpicklingError):
            # A truncated or corrupt entry (e.g. an interrupted write) is
            # treated as a cache miss instead of failing the request.
            embedded_sents = None
        if embedded_sents is not None and len(embedded_sents) == len(sentences):
            return embedded_sents
        # The entry does not match the current sentence list; recompute below
        # and overwrite it.
        print(f'Ignoring unusable cache entry: {cache_path}')

    print(f'Embedding sentences and saving to cache: {cache_path}')
    embedded_sents = embedder_fn(sentences)
    assert len(embedded_sents) == len(sentences)

    # The cache directory may not exist yet (e.g. when this module is
    # imported rather than run as a script), so create it on demand.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with open(cache_path, 'wb') as file:
        pickle.dump(embedded_sents, file)
    return embedded_sents
def calculate_cosine_similarities(embedded_sents, k=1, pool='mean'):
    """Score every gap between consecutive sentences by context similarity.

    For gap ``i`` (between sentence ``i`` and ``i + 1``) the up-to-``k``
    embeddings ending at ``i`` are compared with the up-to-``k`` embeddings
    starting at ``i + 1`` via ``util.cos_sim``; the resulting similarity
    matrix is reduced to one number with ``pool``.

    Args:
        embedded_sents: Sliceable, sized collection of sentence embeddings.
        k: Window size on each side of the gap; must be at least 1.
        pool: Reduction over the pairwise similarity matrix: ``'mean'``,
            ``'max'`` or ``'min'``.

    Returns:
        List of floats with one entry per gap, i.e. ``len(embedded_sents) - 1``
        entries, or an empty list when there are fewer than two sentences.

    Raises:
        ValueError: If ``pool`` is not a supported method or ``k < 1``.
    """
    # Validate up front so a bad setting is reported even when the input is
    # too short for the loop below to run.
    if pool not in ('mean', 'max', 'min'):
        raise ValueError(f'Invalid pooling method: {pool}')
    # Slice bounds must be integers; callers such as UI sliders may hand over
    # a float like 3.0.
    k = int(k)
    if k < 1:
        raise ValueError(f'Invalid window size: {k}')

    cosine_sims = []
    for i in range(len(embedded_sents) - 1):
        lctx = embedded_sents[max(0, i - k + 1): i + 1]
        rctx = embedded_sents[i + 1: i + k + 1]
        sim = util.cos_sim(lctx, rctx)
        # `pool` names the reduction method to call on the similarity matrix.
        cosine_sims.append(getattr(sim, pool)().item())
    return cosine_sims
def predict_boundaries(cosine_sims, threshold):
    """Turn gap similarities into shift scores and 0/1 boundary labels.

    Args:
        cosine_sims: Iterable of similarity values, one per sentence gap.
        threshold: A gap is labelled 1 when ``1.0 - similarity`` is greater
            than or equal to this value.

    Returns:
        ``(preds, probs)``: the list of 0/1 labels and the list of
        ``1.0 - similarity`` scores, both in input order.
    """
    probs = []
    preds = []
    for similarity in cosine_sims:
        shift = 1.0 - similarity
        probs.append(shift)
        if shift >= threshold:
            preds.append(1)
        else:
            preds.append(0)
    return preds, probs
def output_segments(sents, preds, probs):
    """Group sentences into chunks and build the text, JSON and plot outputs.

    Args:
        sents: Sentences in document order.
        preds: One 0/1 label per gap between consecutive sentences; a 1
            closes the current chunk after the sentence left of that gap.
        probs: One score per gap, aligned with ``preds``.

    Returns:
        ``(out_text, out, fig)`` where ``out_text`` is the chunks joined by a
        dashed separator line, ``out`` is a dict with ``'metadata'`` (chunk
        and sentence counts plus mean/std/min/max of ``probs``) and
        ``'chunks'`` (lists of ``{'text', 'prob'}`` dicts), and ``fig`` is the
        figure returned by ``rpt.display``. When there is no gap to score
        (zero or one sentence) the four statistics are ``None`` and ``fig``
        is ``None``.

    Raises:
        AssertionError: If ``preds`` and ``probs`` do not each hold exactly
            one entry per sentence gap.
    """
    # Zero or one sentence means zero gaps (not -1).
    n_gaps = max(len(sents) - 1, 0)
    assert n_gaps == len(preds) == len(probs), f'{len(sents)} - 1 != {len(preds)} != {len(probs)}'

    def iter_segments():
        segment = []
        # preds/probs are one shorter than sents, so zip_longest pairs the
        # final sentence with None for both.
        for sent, pred, prob in zip_longest(sents, preds, probs):
            segment.append({
                'text': sent,
                'prob': round(prob, 4) if prob is not None else None,
            })
            if pred == 1:
                yield segment
                segment = []
        # Whatever follows the last boundary forms the final chunk.
        if segment:
            yield segment

    def plot_regimes(signal, labels):
        # A label of 1 at gap i puts a breakpoint at i + 1; the extra
        # trailing 1 closes the final regime.
        bkps = [i + 1 for i, label in enumerate(list(labels) + [1]) if label == 1]
        fig, [ax] = rpt.display(np.array(signal), bkps, figsize=(10, 5), dpi=250)

        # Pad the y-range around the data but keep it inside [0, 1].
        y_min = max(0.0, min(signal) - 0.1)
        y_max = min(1.0, max(signal) + 0.1)
        ax.set_ylim(y_min, y_max)

        ax.set_title("Segment Regimes")
        ax.set_xlabel("Sentence Index")
        ax.set_ylabel("Semantic Shift Probability")
        fig.tight_layout()
        return fig

    chunks = list(iter_segments())

    if n_gaps > 0:
        stats = {
            'prob_mean': round(float(np.mean(probs)), 4),
            'prob_std': round(float(np.std(probs)), 4),
            'prob_min': round(min(probs), 4),
            'prob_max': round(max(probs), 4),
        }
        fig = plot_regimes(probs, preds)
    else:
        # With no gaps there is nothing to summarise or plot; min()/max()
        # would raise on the empty list and the mean would be NaN.
        stats = {
            'prob_mean': None,
            'prob_std': None,
            'prob_min': None,
            'prob_max': None,
        }
        fig = None

    out = {
        'metadata': {
            'n_chunks': len(chunks),
            'n_sents': sum(len(chunk) for chunk in chunks),
            **stats,
        },
        'chunks': chunks,
    }

    out_text = "\n-------------------------\n".join(
        "\n".join(sent['text'] for sent in chunk) for chunk in chunks
    )
    return out_text, out, fig
def text_segmentation(input_text, model_name, k, pool, threshold):
    """Segment ``input_text`` into chunks and return the three UI outputs.

    Pipeline: split into sentences, embed them (with an on-disk cache),
    score each sentence gap, threshold the scores into boundaries, and
    format the result.

    Args:
        input_text: Raw text to segment.
        model_name: Embedding model name; must be listed in ``_ST_MODELS``.
        k: Context window size forwarded to ``calculate_cosine_similarities``.
        pool: Pooling method forwarded to ``calculate_cosine_similarities``.
        threshold: Boundary threshold forwarded to ``predict_boundaries``.

    Returns:
        The ``(out_text, out, fig)`` tuple produced by ``output_segments``.

    Raises:
        ValueError: If ``model_name`` is not in ``_ST_MODELS``.
    """
    if model_name not in _ST_MODELS:
        raise ValueError(f'Invalid model name: {model_name}')

    def embedder_fn(sentences):
        # Built only when embed_sentences actually needs fresh embeddings, so
        # a cache hit does not pay for loading the model.
        model = SentenceTransformer(model_name, device=DEVICE)
        return model.encode(sentences)

    sents = sent_tokenize(input_text, method='nltk', initial_split_sep='\n')

    # The cache key must cover the model as well as the text; otherwise
    # switching models on the same text would reuse the other model's
    # embeddings. The NUL byte keeps the two fields unambiguous.
    digest = hashlib.md5()
    digest.update(model_name.encode())
    digest.update(b'\0')
    digest.update(input_text.encode())

    cache_dir = Path(CACHE_DIR)
    # Ensure the directory exists even when the module is imported rather
    # than executed as a script.
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = cache_dir / f'{digest.hexdigest()}.pkl'

    embedded_sents = embed_sentences(sents, embedder_fn, cache_path=cache_path)
    cosine_sims = calculate_cosine_similarities(embedded_sents, k=k, pool=pool)
    preds, probs = predict_boundaries(cosine_sims, threshold=threshold)
    return output_segments(sents, preds, probs)
# Gradio UI definition: input controls, output tabs, and the click handler
# that runs text_segmentation.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# that had lost its indentation - confirm it against the intended layout.
with gr.Blocks() as app:
    gr.Markdown("""
# LLM TextTiling Demo
An **extended** approach to text segmentation that combines **TextTiling** with **LLM embeddings**. Simply provide your text, choose an embedding model, and adjust segmentation parameters (window size, pooling, threshold). The demo will split your text into coherent segments based on **semantic shifts**. Refer to the [README](https://huggingface.co/spaces/saeedabc/llm-text-tiling-demo/blob/main/README.md) for more details.
""")
    with gr.Row():
        # Input side: text box, model/pooling selectors, threshold/window
        # sliders and the submit button.
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", placeholder="Enter your text here...", lines=15)
            with gr.Row():
                with gr.Column():
                    # model_name = gr.Radio(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                    model_name = gr.Dropdown(choices=_ST_MODELS, label="Embedding Model", value=_ST_MODELS[0])
                with gr.Column():
                    pool = gr.Dropdown(choices=['max', 'mean', 'min'], label="Pooling Strategy", value='max')
            with gr.Row():
                with gr.Column():
                    threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Threshold", value=0.5)
                with gr.Column():
                    k = gr.Slider(minimum=1, maximum=10, step=1, label="Window Size", value=3)
            submit_button = gr.Button("Chunk Text")
        # Output side: one tab per value returned by text_segmentation.
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Output Text"):
                    output_text = gr.Textbox(label="Output Text", placeholder="Chunks will appear here...", lines=22)
                with gr.Tab("Output Json"):
                    output_json = gr.Json(label="Output Json", open=False, max_height=500)
                with gr.Tab("Output Visualization"):
                    output_fig = gr.Plot(label="Output Visualization")
    # The input order here must match text_segmentation's parameter order
    # (input_text, model_name, k, pool, threshold); its three return values
    # go to the text, JSON and plot components in that order.
    submit_button.click(text_segmentation, inputs=[input_text, model_name, k, pool, threshold], outputs=[output_text, output_json, output_fig])
    # One example row whose values line up with the same five input components.
    examples = gr.Examples(
        examples=[
            ["Rib Mountain is a census-designated place (CDP) in the town of Rib Mountain in Marathon County, Wisconsin, United States. "
             "The population was 5,651 at the 2010 census. "
             "The community is named for Rib Mountain. "
             "According to the United States Census Bureau, the CDP has a total area of 33.8 km² (13.0 mi²). "
             "31.4 km² (12.1 mi²) of it is land and 2.4 km² (0.9 mi²) of it (6.98%) is water. "
             "As of the census of 2000, there were 6,059 people, 2,211 households, and 1,782 families residing in the CDP. "
             "The population density was 193.0/km² (499.8/mi²). "
             "There were 2,278 housing units at an average density of 72.6/km² (187.9/mi²).", "all-mpnet-base-v2", 3, 'max', 0.52],
        ],
        inputs=[input_text, model_name, k, pool, threshold],
    )
if __name__ == '__main__':
    # Script entry point: make sure the embedding cache directory exists,
    # then start the Gradio app.
    cache_dir = Path(CACHE_DIR)
    cache_dir.mkdir(exist_ok=True)
    app.launch()