Spaces:
Runtime error
Runtime error
import html
import re
import sys

import nltk
import numpy as np
import pandas as pd
import plotly.express as px
import sklearn.cluster as cluster
import streamlit as st
import umap.umap_ as umap
# NOTE(review): newer duckduckgo_search releases dropped the `ddg` function in
# favour of the `DDGS` class — pin the package version or migrate; confirm which
# version the deployment installs.
from duckduckgo_search import ddg
from keybert import KeyBERT
from nltk.stem import WordNetLemmatizer
from sentence_transformers import SentenceTransformer
# NLTK data used below: 'punkt' for nltk.word_tokenize, 'wordnet' and
# 'omw-1.4' for WordNetLemmatizer. Recent NLTK releases look up 'punkt_tab'
# instead of 'punkt' inside word_tokenize, so request it as well; on older
# releases that id is unknown and nltk.download just reports the failure and
# returns False instead of raising, so asking for it is harmless.
for nltk_resource in ('punkt', 'punkt_tab', 'omw-1.4', 'wordnet'):
    nltk.download(nltk_resource, quiet=True)
# Set a seed
np.random.seed(42)
# The search bar
keywords = st.text_input('Enter your search', 'How to use ChatGPT')
# NOTE(review): `to_display` is not referenced anywhere in the visible script;
# the embedding and keyword steps read the 'body' column directly.
to_display = 'body'  # Sometimes this is title
# A cleared search box would otherwise send an empty query; halt this run instead.
if not keywords.strip():
    st.info('Enter a search term to map its results.')
    st.stop()
md = ddg(keywords, region='wt-wt', safesearch='Moderate', time='y', max_results=500)
# If the search yields nothing (ddg returning None or an empty list — confirm
# for the installed duckduckgo_search version), pd.DataFrame produces an empty
# frame with no 'body' column and the embedding step would fail with KeyError.
md = pd.DataFrame(md or [])
if md.empty or 'body' not in md.columns:
    st.warning('No search results found. Try a different query.')
    st.stop()
# Load the sentence-transformer model and embed every result snippet.
print("running sentence embeddings...")
# Swap in 'all-mpnet-base-v2' here to use that model instead.
model_name = 'all-MiniLM-L6-v2'
model = SentenceTransformer(model_name)
snippet_texts = list(md['body'])
raw_embeddings = model.encode(snippet_texts, show_progress_bar=True)
# One row per search result, one column per embedding dimension.
sentence_embeddings = pd.DataFrame(raw_embeddings)
# Reduce dimensionality: project the embeddings onto two UMAP coordinates
# (cosine metric) so they can be drawn as a scatter plot.
print("reducing dimensionality...")
reducer = umap.UMAP(metric='cosine')
dimr = pd.DataFrame(
    reducer.fit_transform(sentence_embeddings),
    columns=['umap1', 'umap2'],
)
# NOTE(review): `columns` is not referenced anywhere in the visible script.
columns = ['title', 'href', 'body']
# Clustering: KMeans on the 2-D UMAP coordinates. KMeans rejects inputs with
# fewer samples than clusters, so cap the cluster count for tiny result sets.
n_clusters = min(5, len(dimr))
labels = cluster.KMeans(n_clusters=n_clusters).fit_predict(dimr[['umap1', 'umap2']])
# String labels ("cluster 0", "cluster 1", ...) serve as the plot's colour key
# and can be searched for in the rendered table.
dimr['cluster'] = ['cluster ' + str(label) for label in labels]
# Merge the data together: each search result sits next to its coordinates,
# aligned row by row (both reset_index() calls add an 'index' column).
dat = pd.concat([md.reset_index(), dimr.reset_index()], axis=1)
# The keywords
# Add keywords to the clusters: for every cluster, lemmatize the concatenated
# result snippets and let KeyBERT pick the top keywords.
print('extracting keywords per cluster...')
# Number of keywords kept per cluster; the plot and table steps refer to the
# resulting keyword1..keyword5 columns by name.
n_keywords = 5
keyword_columns = ['keyword' + str(k) for k in range(1, n_keywords + 1)]
# Create WordNetLemmatizer object
wnl = WordNetLemmatizer()
kw_model = KeyBERT()
keyword_rows = []
for cluster_name in np.unique(dat['cluster']):
    cluster_text = ' '.join(dat.loc[dat['cluster'] == cluster_name, 'body'])
    # Lemmatization
    tokens = nltk.word_tokenize(cluster_text)
    lemmatized = ' '.join(wnl.lemmatize(token) for token in tokens)
    # Keyword extraction
    found = list(kw_model.extract_keywords(lemmatized, top_n=n_keywords))[:n_keywords]
    # Pad with None so every row has exactly n_keywords cells; a cluster too
    # small to yield that many keywords would otherwise break the column layout.
    keyword_rows.append(found + [None] * (n_keywords - len(found)) + [cluster_name])
keywords_df = pd.DataFrame(keyword_rows, columns=keyword_columns + ['cluster'])
# handle duplicate index columns: both reset_index() calls in the concat
# produced an 'index' column; keep only the first before joining.
dat = dat.loc[:, ~dat.columns.duplicated()]
# Get the keyword data into the dataframe (inner join on the shared 'cluster' column)
dat = dat.merge(keywords_df, on='cluster')
dat = dat.reset_index(drop=True)
# Get it ready for plotting: wrap long text at 30 characters and turn the
# newlines into <br> so the hover labels break lines. The vectorised
# .str.replace leaves missing values as NaN instead of raising on them.
for col in ['title', 'body']:
    dat[col] = dat[col].str.wrap(30).str.replace('\n', '<br>', regex=False)
# Visualize the data
fig = px.scatter(
    dat,
    x='umap1',
    y='umap2',
    hover_data=['title', 'body', 'keyword1', 'keyword2', 'keyword3', 'keyword4', 'keyword5'],
    color='cluster',
    title='Context similarity map of results',
)
# Make the font a little bigger
fig.update_layout(hoverlabel=dict(bgcolor="white", font_size=16))
# x and y are same size
fig.update_yaxes(scaleanchor="x", scaleratio=1)
# Show the figure
st.plotly_chart(fig, use_container_width=True)
# Remove <br> in the text for the table (plain-text replace; NaN stays NaN)
for col in ['title', 'body']:
    dat[col] = dat[col].str.replace('<br>', ' ', regex=False)
# Instructions
st.caption('Use ctrl+f (or command+f for mac) to search the table')
# remove irrelevant columns from dat
dat = dat.drop(columns=['index', 'umap1', 'umap2', 'keyword1', 'keyword2', 'keyword3', 'keyword4', 'keyword5'])
# pandas display options: None turns off column-width truncation. A value of
# -1 was deprecated in pandas 1.0 and newer versions reject it with ValueError.
pd.set_option('display.max_colwidth', None)
def make_clickable(url, text):
    """Return an HTML anchor that opens ``url`` in a new tab, labelled ``text``.

    The hrefs come from web search results and the table is rendered as raw
    HTML, so both the URL and the label are HTML-escaped. URLs that do not use
    an http(s) scheme (e.g. ``javascript:``) are not turned into links; only
    the escaped label is returned for them.
    """
    url = str(url)
    label = html.escape(str(text))
    if not url.strip().lower().startswith(('http://', 'https://')):
        return label
    return f'<a target="_blank" href="{html.escape(url, quote=True)}">{label}</a>'
# Make the link clickable
dat['href'] = dat['href'].apply(make_clickable, args=('Click here',))
# to_html(escape=False) is needed for the generated <a> tags to render, but it
# would also emit the search-result titles/snippets as raw HTML. Escape every
# other text column first so only the generated links are live markup.
for col in dat.columns.drop('href'):
    dat[col] = dat[col].map(lambda value: html.escape(value) if isinstance(value, str) else value)
st.write(dat.to_html(escape=False), unsafe_allow_html=True)