cstimson committed on
Commit
d64128a
1 Parent(s): 5c14c16

Create new file

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import nltk
3
+ from transformers import pipeline
4
+ from sentence_transformers import SentenceTransformer
5
+ from scipy.spatial.distance import cosine
6
+ import numpy as np
7
+ import seaborn as sns
8
+ import matplotlib.pyplot as plt
9
+ from sklearn.cluster import KMeans
10
+ import tensorflow as tf
11
+ import tensorflow_hub as hub
12
+
13
+
14
def cluster_examples(messages, embed, nc=3):
    """Cluster sentences with K-Means and list each cluster in Streamlit.

    Args:
        messages: Sequence of sentences; expected to be aligned by index
            with the rows of ``embed``.
        embed: Array-like of embedding vectors passed to
            ``KMeans.fit_predict``.
        nc: Requested number of clusters. It is capped at the number of
            messages, because K-Means raises ``ValueError`` when asked for
            more clusters than there are samples.
    """
    n_clusters = min(nc, len(messages))
    if n_clusters < 1:
        # Nothing to cluster (no messages, or a non-positive request).
        return

    labels = KMeans(
        n_clusters=n_clusters,
        init='random',
        n_init=10,
        max_iter=300,
        tol=1e-04,
        random_state=0,
    ).fit_predict(embed)

    for cluster_id in range(n_clusters):
        st.markdown("CLUSTER : %d" % cluster_id)
        for message, label in zip(messages, labels):
            if label == cluster_id:
                st.markdown(message)
27
+
28
+
29
def plot_heatmap(labels, heatmap, rotation=90):
    """Render a labelled heatmap titled "Textual Similarity" in Streamlit.

    Args:
        labels: Tick labels used for both the x and y axes.
        heatmap: Rectangular data handed to ``seaborn.heatmap``; the colour
            scale is pinned to ``vmin=-1`` / ``vmax=1``.
        rotation: Rotation passed to ``set_xticklabels`` for the x labels.
    """
    sns.set(font_scale=1.2)
    fig, ax = plt.subplots()
    try:
        # Draw on the axes created above explicitly rather than relying on
        # pyplot's implicit "current axes".
        sns.heatmap(
            heatmap,
            xticklabels=labels,
            yticklabels=labels,
            vmin=-1,
            vmax=1,
            cmap="coolwarm",
            ax=ax,
        )
        ax.set_xticklabels(labels, rotation=rotation)
        ax.set_title("Textual Similarity")

        st.pyplot(fig)
    finally:
        # pyplot keeps a reference to every figure until it is closed, so
        # without this each call would leave one more figure in memory.
        plt.close(fig)
43
+ #plt.show()
44
+
45
+ #st.header("Sentence Similarity Demo")
46
+
47
# Labels for the model selector; reused below to dispatch on the choice.
SENTENCE_TRANSFORMER = "Sentence Transformer"
UNIVERSAL_SENTENCE_ENCODER = "Universal Sentence Encoder"

# Streamlit text boxes
text = st.text_area('Enter sentences:', value="The sun is hotter than the moon.\nThe sun is very bright.\nI hear that the universe is very large.\nToday is Tuesday.")

nc = st.slider('Select a number of clusters:', min_value=1, max_value=15, value=3)

model_type = st.radio("Choose model:", (SENTENCE_TRANSFORMER, UNIVERSAL_SENTENCE_ENCODER), index=0)

# Model setup
# NOTE(review): this runs at module top level, so the model is loaded again
# whenever Streamlit re-executes the script; consider caching the loaded
# model -- confirm which caching API the deployed Streamlit version offers.
if model_type == SENTENCE_TRANSFORMER:
    model = SentenceTransformer('paraphrase-distilroberta-base-v1')
elif model_type == UNIVERSAL_SENTENCE_ENCODER:
    model_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
    model = hub.load(model_url)

# Tokenizer data required by nltk.tokenize.sent_tokenize.
nltk.download('punkt')

# Run model
# Whitespace-only input is truthy but tokenizes to no sentences, and the
# heatmap / K-Means steps cannot handle an empty embedding set, so gate on
# the sentence list rather than on the raw text.
sentences = nltk.tokenize.sent_tokenize(text) if text else []
if sentences:
    if model_type == SENTENCE_TRANSFORMER:
        embed = model.encode(sentences)
    elif model_type == UNIVERSAL_SENTENCE_ENCODER:
        embed = model(sentences).numpy()

    # Pairwise cosine similarity (1 - cosine distance) between embeddings.
    sim = np.zeros([len(embed), len(embed)])
    for i, em in enumerate(embed):
        for j, ea in enumerate(embed):
            sim[i, j] = 1.0 - cosine(em, ea)

    st.subheader("Similarity Heatmap")
    plot_heatmap(sentences, sim)
    st.subheader("Results from K-Means Clustering")
    cluster_examples(sentences, embed, nc)