Added an option to select which models to search for the word in
Browse files
- app.py +16 -3
- word2vec.py +39 -4
app.py
CHANGED
@@ -20,6 +20,7 @@ if active_tab == "Nearest neighbours":
|
|
20 |
with col2:
|
21 |
time_slice = st.selectbox("Time slice", ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"])
|
22 |
|
|
|
23 |
n = st.slider("Number of neighbours", 1, 50, 15)
|
24 |
|
25 |
nearest_neighbours_button = st.button("Find nearest neighbours")
|
@@ -28,14 +29,26 @@ if active_tab == "Nearest neighbours":
|
|
28 |
if nearest_neighbours_button:
|
29 |
|
30 |
# Rewrite timeslices to model names: Archaic -> archaic_cbow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
time_slice = time_slice.lower() + "_cbow"
|
32 |
-
|
|
|
33 |
|
34 |
# Check if all fields are filled in
|
35 |
-
if validate_nearest_neighbours(word, time_slice, n) == False:
|
36 |
st.error('Please fill in all fields')
|
37 |
else:
|
38 |
-
|
|
|
|
|
|
|
|
|
39 |
df = pd.DataFrame(nearest_neighbours, columns=["Word", "Time slice", "Similarity"])
|
40 |
st.table(df)
|
41 |
|
|
|
20 |
with col2:
|
21 |
time_slice = st.selectbox("Time slice", ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"])
|
22 |
|
23 |
+
models = st.multiselect("Select models to search for neighbours", ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"])
|
24 |
n = st.slider("Number of neighbours", 1, 50, 15)
|
25 |
|
26 |
nearest_neighbours_button = st.button("Find nearest neighbours")
|
|
|
29 |
if nearest_neighbours_button:
|
30 |
|
31 |
# Rewrite timeslices to model names: Archaic -> archaic_cbow
|
32 |
+
if time_slice == 'Hellenistic':
|
33 |
+
time_slice = 'hellen'
|
34 |
+
elif time_slice == 'Early Roman':
|
35 |
+
time_slice = 'early_roman'
|
36 |
+
elif time_slice == 'Late Roman':
|
37 |
+
time_slice = 'late_roman'
|
38 |
+
|
39 |
time_slice = time_slice.lower() + "_cbow"
|
40 |
+
|
41 |
+
|
42 |
|
43 |
# Check if all fields are filled in
|
44 |
+
if validate_nearest_neighbours(word, time_slice, n, models) == False:
|
45 |
st.error('Please fill in all fields')
|
46 |
else:
|
47 |
+
# Replace the selected model names with the corresponding loaded models
|
48 |
+
models = load_selected_models(models)
|
49 |
+
|
50 |
+
nearest_neighbours = get_nearest_neighbours(word, time_slice, n, models)
|
51 |
+
|
52 |
df = pd.DataFrame(nearest_neighbours, columns=["Word", "Time slice", "Similarity"])
|
53 |
st.table(df)
|
54 |
|
word2vec.py
CHANGED
@@ -18,6 +18,24 @@ def load_all_models():
|
|
18 |
return [archaic, classical, early_roman, hellen, late_roman]
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def load_word2vec_model(model_path):
|
22 |
'''
|
23 |
Load a word2vec model from a file
|
@@ -120,15 +138,31 @@ def get_cosine_similarity_one_word(word, time_slice1, time_slice2):
|
|
120 |
|
121 |
|
122 |
|
123 |
-
def validate_nearest_neighbours(word, time_slice_model, n):
    '''
    Validate the input of the nearest neighbours function

    Returns False when any field is missing and True otherwise:
    - word is None, empty or made only of whitespace
    - time_slice_model is None, an empty string or an empty list
    - n is None or an empty string
    '''
    # A word made only of whitespace carries no query, so treat it like ''
    if word is None or (isinstance(word, str) and word.strip() == ''):
        return False
    if time_slice_model is None or time_slice_model == '' or time_slice_model == []:
        return False
    if n is None or n == '':
        return False
    return True
|
130 |
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models()):
|
133 |
'''
|
134 |
Return the nearest neighbours of a word
|
@@ -149,6 +183,7 @@ def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models(
|
|
149 |
# Iterate over all models
|
150 |
for model in models:
|
151 |
model_name = model[0]
|
|
|
152 |
model = model[1]
|
153 |
|
154 |
# Iterate over all words of the model
|
@@ -162,14 +197,14 @@ def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models(
|
|
162 |
|
163 |
# If the list of nearest neighbours is not full yet, add the current word
|
164 |
if len(nearest_neighbours) < n:
|
165 |
-
nearest_neighbours.append((word,
|
166 |
|
167 |
# If the list of nearest neighbours is full, replace the word with the smallest cosine similarity
|
168 |
else:
|
169 |
smallest_neighbour = min(nearest_neighbours, key=lambda x: x[2])
|
170 |
if cosine_similarity_vectors > smallest_neighbour[2]:
|
171 |
nearest_neighbours.remove(smallest_neighbour)
|
172 |
-
nearest_neighbours.append((word,
|
173 |
|
174 |
|
175 |
return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
|
|
|
18 |
return [archaic, classical, early_roman, hellen, late_roman]
|
19 |
|
20 |
|
21 |
+
def load_selected_models(selected_models):
    '''
    Load the selected word2vec models

    selected_models is an iterable of time slice labels (for example
    "Archaic" or "Early Roman"). Each label is rewritten to a model name
    such as "archaic_cbow" and the model is loaded from
    models/<model_name>.model.

    Returns a list of [model_name, model] pairs, in the order the labels
    were given. An empty selection returns an empty list.
    '''
    # Labels whose file stem is not simply the lower-cased label
    special_stems = {
        "Early Roman": "early_roman",
        "Late Roman": "late_roman",
        "Hellenistic": "hellen",
    }

    models = []
    for label in selected_models:
        # Any other label falls back to its lower-cased form
        model_name = special_stems.get(label, label).lower() + "_cbow"
        models.append([model_name, load_word2vec_model(f'models/{model_name}.model')])

    return models
|
37 |
+
|
38 |
+
|
39 |
def load_word2vec_model(model_path):
|
40 |
'''
|
41 |
Load a word2vec model from a file
|
|
|
138 |
|
139 |
|
140 |
|
141 |
+
def validate_nearest_neighbours(word, time_slice_model, n, models):
    '''
    Validate the input of the nearest neighbours function

    Returns False when any field is missing and True otherwise:
    - word is None, empty or made only of whitespace
    - time_slice_model is None, an empty string or an empty list
    - n is None or an empty string
    - models is None or an empty selection
    '''
    # A word made only of whitespace carries no query, so treat it like ''
    if word is None or (isinstance(word, str) and word.strip() == ''):
        return False
    if time_slice_model is None or time_slice_model == '' or time_slice_model == []:
        return False
    if n is None or n == '':
        return False
    # len() also rejects an empty tuple, which "== []" would let through
    if models is None or len(models) == 0:
        return False
    return True
|
148 |
|
149 |
|
150 |
+
def convert_model_to_time_name(model_name):
    '''
    Convert the model name to the time slice name

    Returns the time slice name for one of the five known model names
    (for example 'hellen_cbow' -> 'Hellenistic'). Any other model name
    returns None.
    '''
    time_names = {
        'archaic_cbow': 'Archaic',
        'classical_cbow': 'Classical',
        'early_roman_cbow': 'Early Roman',
        'hellen_cbow': 'Hellenistic',
        'late_roman_cbow': 'Late Roman',
    }
    # dict.get returns None for names that are not in the table
    return time_names.get(model_name)
|
164 |
+
|
165 |
+
|
166 |
def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models()):
|
167 |
'''
|
168 |
Return the nearest neighbours of a word
|
|
|
183 |
# Iterate over all models
|
184 |
for model in models:
|
185 |
model_name = model[0]
|
186 |
+
time_name = convert_model_to_time_name(model_name)
|
187 |
model = model[1]
|
188 |
|
189 |
# Iterate over all words of the model
|
|
|
197 |
|
198 |
# If the list of nearest neighbours is not full yet, add the current word
|
199 |
if len(nearest_neighbours) < n:
|
200 |
+
nearest_neighbours.append((word, time_name, cosine_similarity_vectors))
|
201 |
|
202 |
# If the list of nearest neighbours is full, replace the word with the smallest cosine similarity
|
203 |
else:
|
204 |
smallest_neighbour = min(nearest_neighbours, key=lambda x: x[2])
|
205 |
if cosine_similarity_vectors > smallest_neighbour[2]:
|
206 |
nearest_neighbours.remove(smallest_neighbour)
|
207 |
+
nearest_neighbours.append((word, time_name, cosine_similarity_vectors))
|
208 |
|
209 |
|
210 |
return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
|