Mark7549 commited on
Commit
169869e
1 Parent(s): 14c3a4f

Added option to select models to search word in

Browse files
Files changed (2) hide show
  1. app.py +16 -3
  2. word2vec.py +39 -4
app.py CHANGED
@@ -20,6 +20,7 @@ if active_tab == "Nearest neighbours":
20
  with col2:
21
  time_slice = st.selectbox("Time slice", ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"])
22
 
 
23
  n = st.slider("Number of neighbours", 1, 50, 15)
24
 
25
  nearest_neighbours_button = st.button("Find nearest neighbours")
@@ -28,14 +29,26 @@ if active_tab == "Nearest neighbours":
28
  if nearest_neighbours_button:
29
 
30
  # Rewrite timeslices to model names: Archaic -> archaic_cbow
 
 
 
 
 
 
 
31
  time_slice = time_slice.lower() + "_cbow"
32
- st.write(time_slice)
 
33
 
34
  # Check if all fields are filled in
35
- if validate_nearest_neighbours(word, time_slice, n) == False:
36
  st.error('Please fill in all fields')
37
  else:
38
- nearest_neighbours = get_nearest_neighbours(word, time_slice, n)
 
 
 
 
39
  df = pd.DataFrame(nearest_neighbours, columns=["Word", "Time slice", "Similarity"])
40
  st.table(df)
41
 
 
20
  with col2:
21
  time_slice = st.selectbox("Time slice", ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"])
22
 
23
+ models = st.multiselect("Select models to search for neighbours", ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"])
24
  n = st.slider("Number of neighbours", 1, 50, 15)
25
 
26
  nearest_neighbours_button = st.button("Find nearest neighbours")
 
29
  if nearest_neighbours_button:
30
 
31
  # Rewrite timeslices to model names: Archaic -> archaic_cbow
32
+ if time_slice == 'Hellenistic':
33
+ time_slice = 'hellen'
34
+ elif time_slice == 'Early Roman':
35
+ time_slice = 'early_roman'
36
+ elif time_slice == 'Late Roman':
37
+ time_slice = 'late_roman'
38
+
39
  time_slice = time_slice.lower() + "_cbow"
40
+
41
+
42
 
43
  # Check if all fields are filled in
44
+ if validate_nearest_neighbours(word, time_slice, n, models) == False:
45
  st.error('Please fill in all fields')
46
  else:
47
+ # Rewrite models to list of all loaded models
48
+ models = load_selected_models(models)
49
+
50
+ nearest_neighbours = get_nearest_neighbours(word, time_slice, n, models)
51
+
52
  df = pd.DataFrame(nearest_neighbours, columns=["Word", "Time slice", "Similarity"])
53
  st.table(df)
54
 
word2vec.py CHANGED
@@ -18,6 +18,24 @@ def load_all_models():
18
  return [archaic, classical, early_roman, hellen, late_roman]
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def load_word2vec_model(model_path):
22
  '''
23
  Load a word2vec model from a file
@@ -120,15 +138,31 @@ def get_cosine_similarity_one_word(word, time_slice1, time_slice2):
120
 
121
 
122
 
123
- def validate_nearest_neighbours(word, time_slice_model, n):
124
  '''
125
  Validate the input of the nearest neighbours function
126
  '''
127
- if word == '' or time_slice_model == [] or n == '':
128
  return False
129
  return True
130
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models()):
133
  '''
134
  Return the nearest neighbours of a word
@@ -149,6 +183,7 @@ def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models(
149
  # Iterate over all models
150
  for model in models:
151
  model_name = model[0]
 
152
  model = model[1]
153
 
154
  # Iterate over all words of the model
@@ -162,14 +197,14 @@ def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models(
162
 
163
  # If the list of nearest neighbours is not full yet, add the current word
164
  if len(nearest_neighbours) < n:
165
- nearest_neighbours.append((word, model_name, cosine_similarity_vectors))
166
 
167
  # If the list of nearest neighbours is full, replace the word with the smallest cosine similarity
168
  else:
169
  smallest_neighbour = min(nearest_neighbours, key=lambda x: x[2])
170
  if cosine_similarity_vectors > smallest_neighbour[2]:
171
  nearest_neighbours.remove(smallest_neighbour)
172
- nearest_neighbours.append((word, model_name, cosine_similarity_vectors))
173
 
174
 
175
  return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
 
18
  return [archaic, classical, early_roman, hellen, late_roman]
19
 
20
 
21
+ def load_selected_models(selected_models):
22
+ '''
23
+ Load the selected word2vec models
24
+ '''
25
+ models = []
26
+ for model in selected_models:
27
+ if model == "Early Roman":
28
+ model = "early_roman"
29
+ elif model == "Late Roman":
30
+ model = "late_roman"
31
+ elif model == "Hellenistic":
32
+ model = "hellen"
33
+ model_name = model.lower() + "_cbow"
34
+ models.append([model_name, load_word2vec_model(f'models/{model_name}.model')])
35
+
36
+ return models
37
+
38
+
39
  def load_word2vec_model(model_path):
40
  '''
41
  Load a word2vec model from a file
 
138
 
139
 
140
 
141
+ def validate_nearest_neighbours(word, time_slice_model, n, models):
142
  '''
143
  Validate the input of the nearest neighbours function
144
  '''
145
+ if word == '' or time_slice_model == [] or n == '' or models == []:
146
  return False
147
  return True
148
 
149
 
150
+ def convert_model_to_time_name(model_name):
151
+ '''
152
+ Convert the model name to the time slice name
153
+ '''
154
+ if model_name == 'archaic_cbow':
155
+ return 'Archaic'
156
+ elif model_name == 'classical_cbow':
157
+ return 'Classical'
158
+ elif model_name == 'early_roman_cbow':
159
+ return 'Early Roman'
160
+ elif model_name == 'hellen_cbow':
161
+ return 'Hellenistic'
162
+ elif model_name == 'late_roman_cbow':
163
+ return 'Late Roman'
164
+
165
+
166
  def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models()):
167
  '''
168
  Return the nearest neighbours of a word
 
183
  # Iterate over all models
184
  for model in models:
185
  model_name = model[0]
186
+ time_name = convert_model_to_time_name(model_name)
187
  model = model[1]
188
 
189
  # Iterate over all words of the model
 
197
 
198
  # If the list of nearest neighbours is not full yet, add the current word
199
  if len(nearest_neighbours) < n:
200
+ nearest_neighbours.append((word, time_name, cosine_similarity_vectors))
201
 
202
  # If the list of nearest neighbours is full, replace the word with the smallest cosine similarity
203
  else:
204
  smallest_neighbour = min(nearest_neighbours, key=lambda x: x[2])
205
  if cosine_similarity_vectors > smallest_neighbour[2]:
206
  nearest_neighbours.remove(smallest_neighbour)
207
+ nearest_neighbours.append((word, time_name, cosine_similarity_vectors))
208
 
209
 
210
  return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)