Michael-Geis committed on
Commit
baf74d8
1 Parent(s): c8cfc43

Switched the tag embedding format from Feather to Parquet so that the index is saved with the data

Browse files
Files changed (1) hide show
  1. embedding.py +4 -11
embedding.py CHANGED
@@ -111,16 +111,9 @@ def generate_tag_embeddings(model_name, path_to_tag_dict, path_to_save_embedding
111
  embedded_tag_names_df.columns = [
112
  str(name) for name in embedded_tag_names_df.columns
113
  ]
 
 
114
 
115
- embedded_tag_names_df.to_feather(path_to_save_embeddings)
116
 
117
-
118
- def load_tag_embeddings(path_to_tag_dict, path_to_tag_embeddings):
119
- with open(path_to_tag_dict, "r") as file:
120
- dict_string = file.read()
121
- tag_dict = json.loads(dict_string)
122
-
123
- tag_name_list = list(tag_dict.values())
124
- tag_name_embeddings = pd.read_feather(path_to_tag_embeddings)
125
- tag_name_embeddings.index = tag_name_list
126
- return tag_name_embeddings
 
111
  embedded_tag_names_df.columns = [
112
  str(name) for name in embedded_tag_names_df.columns
113
  ]
114
+ embedded_tag_names_df.index = tag_name_list
115
+ embedded_tag_names_df.to_parquet(path_to_save_embeddings, index=True)
116
 
 
117
 
118
+ def load_tag_embeddings(path_to_tag_embeddings):
119
+ return pd.read_parquet(path_to_tag_embeddings)