DataAIDemo / pages /topic_classification.py
themeetjani's picture
Update pages/topic_classification.py
d1cfdb3 verified
raw
history blame
2.73 kB
# Import all the necessary packages.
import streamlit as st
from streamlit import session_state
import pandas as pd
import numpy as np
from scipy import spatial
from sentence_transformers import SentenceTransformer
import json
@st.cache_resource(show_spinner=False)
def _load_model():
    """Load the Hugging Face sentence-embedding model once per server process.

    Streamlit re-executes this whole page script on every widget interaction;
    caching the resource avoids reloading the model weights on each rerun.
    show_spinner=False keeps the cached call from rendering an element before
    st.set_page_config is called further down the script.
    """
    return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')


# Shared embedding model used for the taxonomy labels and the user's text.
model = _load_model()
def cosine_similarity(x, y):
    """Return the cosine similarity between vectors *x* and *y*.

    scipy's ``spatial.distance.cosine`` yields the cosine *distance*
    (1 - similarity), so the similarity is recovered by subtracting it from 1.
    """
    distance = spatial.distance.cosine(x, y)
    return 1 - distance
# Read the topic taxonomy spreadsheet into a dataframe.
df = pd.read_excel('topic_data.xlsx')

# Group the level-2 labels under their level-1 segment:
# {'LEVEL 1' value: ['new_level_2' values, ...]}.
result_dict = df.groupby('LEVEL 1')['new_level_2'].apply(list).to_dict()

# Level-1 segment names and their embeddings; row i of segments_encode is the
# embedding of segments[i] (the two are zipped together when scoring).
segments = list(result_dict.keys())
segments_encode = model.encode(segments)

# Level-2 embeddings per level-1 segment; row j of embeddings_dict[key] is the
# embedding of result_dict[key][j].
embeddings_dict = {
    key: model.encode(labels) for key, labels in result_dict.items()
}
def segments_finder(text_encode):
    """Rank every level-1 segment by cosine similarity to *text_encode*.

    Returns a list of ``(segment_name, score)`` pairs, highest score first.
    """
    scores = {
        name: cosine_similarity(vector, text_encode)
        for vector, name in zip(segments_encode, segments)
    }
    ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
    return ranked
def _top_matches(names, vectors, text_encode, k):
    """Return the *k* best ``{name: score}`` pairs by cosine similarity, best first."""
    # float(): keep numpy scalars out of the result so its str() in the UI is readable.
    scores = {
        name: float(cosine_similarity(vector, text_encode))
        for vector, name in zip(vectors, names)
    }
    return dict(sorted(scores.items(), key=lambda pair: pair[1], reverse=True)[:k])


def level2(article_summary, top_l1=2, top_l2=2):
    """Classify *article_summary* into level-1 and level-2 topics.

    The text is embedded once. Level-1 segments are ranked by cosine
    similarity, and for each of the best ``top_l1`` segments its own level-2
    labels are ranked the same way.

    Args:
        article_summary: text to classify.
        top_l1: number of level-1 segments to keep (default 2).
        top_l2: number of level-2 labels kept per level-1 segment (default 2).

    Returns:
        ``{'l1': {segment: score, ...},
           'l2': {segment: {label: score, ...}, ...}}``
        where every inner dict is ordered best score first.
    """
    text_encode = model.encode(article_summary)
    # segments_finder already returns pairs sorted best-first, so slicing is
    # enough to get the top level-1 segments.
    l1_top = segments_finder(text_encode)[:top_l1]

    l2 = {}
    for l1_name, _ in l1_top:
        l2[l1_name] = _top_matches(
            result_dict[l1_name], embeddings_dict[l1_name], text_encode, top_l2
        )

    return {
        'l1': {name: float(score) for name, score in l1_top},
        'l2': l2,
    }
st.set_page_config(page_title="topic_classification", page_icon="📈")

# Holds the most recent classification result across Streamlit reruns.
if 'topic_class' not in session_state:
    session_state['topic_class'] = ""

st.title("Topic Classifier")
text = st.text_area(
    label="Please write the text below",
    placeholder="What does the tweet say?",
)
def classify(text):
    """Button callback: classify *text* and store the result in session state."""
    prediction = level2(text)
    session_state['topic_class'] = prediction
# Show the latest classification result (empty until Submit is first pressed).
st.text_area(label="result", value=session_state['topic_class'])

# Submit runs classify(text) as a callback before the next rerun.
st.button(label="Submit", on_click=classify, args=[text])