Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 -u | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import glob | |
| import numpy as np | |
# Width of every embedding vector: load_embeddings reshapes each flat float32
# shard into rows of this many values. Presumably the encoder's hidden size —
# confirm against the model that produced the shards.
DIM: int = 1024
def _l2_normalize(mat):
    """Scale each row of a 2-D array to unit L2 norm; all-zero rows stay zero."""
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    # Dividing by a zero norm would give NaN rows. argsort treats NaN as the
    # largest value, so such a row would become the top-1 neighbour of every
    # query.
    return mat / np.where(norms == 0, 1, norms)


def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
    """Find the ``k`` most cosine-similar target sentences for each source sentence.

    Args:
        source_embs: dict mapping sentence id -> 1-D embedding vector.
        target_embs: dict mapping sentence id -> 1-D embedding vector of the
            same length as the source vectors.
        k: maximum number of neighbours kept per source sentence.
        return_sim_mat: if True, return the similarity matrix instead of the
            neighbour map.

    Returns:
        If ``return_sim_mat`` is True, an array of shape
        ``(len(source_embs), len(target_embs))``. Its rows follow the iteration
        order of ``source_embs`` and its columns follow ``target_embs``.
        Otherwise, a dict mapping each source id to a list of up to ``k``
        target ids, most similar first.

    Raises:
        ValueError: if either dict is empty (``np.stack`` needs at least one
            array).
    """
    target_ids = list(target_embs)
    # np.stack needs a real sequence; recent NumPy releases reject the
    # dict_values view that was passed here before.
    source_mat = np.stack(list(source_embs.values()), axis=0)
    target_mat = np.stack(list(target_embs.values()), axis=0)
    sim_mat = _l2_normalize(source_mat).dot(_l2_normalize(target_mat).T)
    if return_sim_mat:
        return sim_mat
    # Sort all rows in one call (ascending), then reverse each row so the
    # most similar target comes first. Per row this matches sorting the rows
    # one at a time.
    top_idx = np.argsort(sim_mat, axis=1)[:, ::-1][:, :k]
    return {
        sentence_id: [target_ids[col] for col in row]
        for sentence_id, row in zip(source_embs, top_idx)
    }
def load_embeddings(directory, LANGS, dim=None):
    """Load sentence embeddings and sentence texts for each language.

    For every ``lang`` in ``LANGS``, each shard file
    ``{directory}/{lang}/all_avg_pool.{lang}.<shard>`` is read as raw float32
    values and reshaped into rows of width ``dim``. It is paired with
    ``{directory}/{lang}/sentences.{lang}.<shard>``, whose i-th line has the
    form ``<sentence_id>\\t<sentence>`` and belongs to the i-th embedding row.

    Args:
        directory: root directory holding one subdirectory per language.
        LANGS: iterable of language codes to load.
        dim: embedding width. Defaults to the module-level ``DIM``.

    Returns:
        A tuple ``(sentence_embeddings, sentence_texts)`` of dicts keyed by
        language. Each inner dict maps a sentence id to its embedding row or
        to its sentence text, respectively.

    Raises:
        ValueError: if a shard's size is not a multiple of ``dim``, if a text
            line has no tab separator, or if a text file has more lines than
            its shard has embedding rows.
    """
    if dim is None:
        dim = DIM
    sentence_embeddings = {}
    sentence_texts = {}
    for lang in LANGS:
        lang_embeddings = sentence_embeddings[lang] = {}
        lang_texts = sentence_texts[lang] = {}
        lang_dir = f"{directory}/{lang}"
        # glob() order is arbitrary. Sorting makes dict insertion order
        # deterministic, and that order fixes row order (and tie-breaking)
        # in compute_dist.
        for embed_file in sorted(glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")):
            shard_id = embed_file.split(".")[-1]
            embeddings = np.fromfile(embed_file, dtype=np.float32)
            embeddings = embeddings.reshape((embeddings.shape[0] // dim, dim))
            num_rows = embeddings.shape[0]
            sentence_path = f"{lang_dir}/sentences.{lang}.{shard_id}"
            # The text can be in any language, so do not depend on the
            # locale's default encoding.
            with open(sentence_path, encoding="utf-8") as sentence_file:
                for idx, line in enumerate(sentence_file):
                    if idx >= num_rows:
                        raise ValueError(
                            f"{sentence_path} has more lines than the "
                            f"{num_rows} embedding rows in {embed_file}"
                        )
                    # Split only on the first tab, so a sentence may contain
                    # tabs or be empty. Strip each field, as the whole-line
                    # strip did before.
                    fields = line.rstrip("\r\n").split("\t", 1)
                    if len(fields) != 2:
                        raise ValueError(
                            f"{sentence_path}:{idx + 1}: expected "
                            "'<sentence_id>\\t<sentence>'"
                        )
                    sentence_id, sentence = (field.strip() for field in fields)
                    lang_texts[sentence_id] = sentence
                    lang_embeddings[sentence_id] = embeddings[idx, :]
    return sentence_embeddings, sentence_texts
def compute_accuracy(directory, LANGS):
    """Compute, print and save top-1 cross-lingual retrieval accuracy.

    For every ordered pair ``(source_lang, target_lang)`` in ``LANGS``, each
    source sentence is matched to its most similar target sentence using
    ``compute_dist``. A hit is counted when that target has the same
    sentence id. Accuracy is hits divided by the number of source sentences
    (queries).

    The report is printed and also written to ``{directory}/accuracy``. It
    has a header line of language codes, then one line per source language:
    the language code followed by one accuracy value per target language.

    Args:
        directory: root directory passed to ``load_embeddings``; the report
            file is written here.
        LANGS: list of language codes.

    Returns:
        dict mapping ``source_lang -> {target_lang: top-1 accuracy}``.
    """
    sentence_embeddings, _ = load_embeddings(directory, LANGS)
    top_1_accuracy = {}
    report_lines = [" ".join(LANGS) + "\n"]
    for source_lang in LANGS:
        source_embs = sentence_embeddings[source_lang]
        row = top_1_accuracy[source_lang] = {}
        cells = []
        for target_lang in LANGS:
            # Only the nearest neighbour is needed for top-1 accuracy.
            neighbors_map = compute_dist(
                source_embs, sentence_embeddings[target_lang], k=1
            )
            hits = sum(
                1
                for sentence_id, neighbors in neighbors_map.items()
                if neighbors[0] == sentence_id
            )
            # Normalise by the number of queries, not by the size of the
            # target pool.
            row[target_lang] = hits / len(neighbors_map)
            cells.append(f"{row[target_lang]} ")
        report_lines.append(f"{source_lang} " + "".join(cells) + "\n")
    report = "".join(report_lines)
    print(report)
    # Use a context manager so the report file is flushed and closed.
    with open(f"{directory}/accuracy", "w", encoding="utf-8") as report_file:
        print(report, file=report_file)
    return top_1_accuracy
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze encoder outputs")
    parser.add_argument(
        "directory",
        help="Directory with one embedding subdirectory per language; "
        "the accuracy report is written here",
    )
    # Required: the languages cannot be inferred, and a missing value used to
    # crash with AttributeError on None.split().
    parser.add_argument(
        "--langs",
        required=True,
        help="Comma-separated list of language codes, e.g. en,de,fr",
    )
    args = parser.parse_args()
    # Tolerate spaces and stray commas such as "en, de,".
    langs = [lang.strip() for lang in args.langs.split(",") if lang.strip()]
    if not langs:
        parser.error("--langs must name at least one language")
    compute_accuracy(args.directory, langs)